{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import PreTrainedTokenizerFast, GenerationConfig , PhiForCausalLM, pipeline\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型保存路径\n",
    "这里加载sft后的模型和dpo后的模型，看看两者输出的区别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer_dir = './model_save/tokenizer/'\n",
    "dpo_model_save_dir = './model_save/dpo/'\n",
    "sft_model_save_dir = './model_save/sft/'\n",
    "max_seq_len = 320"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Phi-2 size: 193.7M parameters\n"
     ]
    }
   ],
   "source": [
    "tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)\n",
    "dpo_model = PhiForCausalLM.from_pretrained(dpo_model_save_dir).to(device)\n",
    "\n",
    "sft_mdodel = PhiForCausalLM.from_pretrained(sft_model_save_dir).to(device)\n",
    "\n",
    "model_size = sum(t.numel() for t in dpo_model.parameters())\n",
    "print(f\"Phi-2 size: {model_size / 1000**2:.1f}M parameters\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义两个text-generation的pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "dpo_pipe = pipeline(\n",
    "    \"text-generation\", \n",
    "    model=dpo_model, \n",
    "    device=device,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "sft_pipe = pipeline(\n",
    "    \"text-generation\", \n",
    "    model=sft_mdodel, \n",
    "    device=device,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model after sft output:\n",
      "\n",
      "##提问:\n",
      "感冒了要怎么办？\n",
      "##回答:\n",
      "感冒时最好休息，多喝水，避免吃辛辣刺激的食物，如辛辣食物，以及保持室内空气流通。\n",
      "\n",
      "==================\n",
      "\n",
      "model after dpo output:\n",
      "\n",
      "##提问:\n",
      "感冒了要怎么办？\n",
      "##回答:\n",
      "感冒是由病毒引起的，感冒一般由病毒引起，以下是一些常见感冒的方法：\n",
      "- 洗手，特别是在接触其他人或物品后。\n",
      "- 咳嗽或打喷嚏时用纸巾或手肘遮住口鼻。\n",
      "- 触摸面部、鼻子和眼睛。\n",
      "- 喝充足的水，特别是在感冒季节。\n",
      "- 喝一些温水或鸡汤，以帮助身体降温。\n",
      "- 如果感冒症状严重，如发热或咳嗽，可能需要就医。\n"
     ]
    }
   ],
   "source": [
    "txt = '感冒了要怎么办？'\n",
    "prompt = f\"##提问:\\n{txt}\\n##回答:\\n\"\n",
    "sft_outputs = sft_pipe(prompt, num_return_sequences=1, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)\n",
    "dpo_outputs = dpo_pipe(prompt, num_return_sequences=1, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "\n",
    "print(f\"model after sft output:\\n\\n{sft_outputs[0]['generated_text']}\")\n",
    "print('\\n==================\\n')\n",
    "print(f\"model after dpo output:\\n\\n{dpo_outputs[0]['generated_text']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 不用pipeline组件，使用greedy search方法手动生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# greedy search\n",
    "gen_conf = GenerationConfig(\n",
    "    num_beams=1,\n",
    "    do_sample=False,\n",
    "    max_length=320,\n",
    "    max_new_tokens=256,\n",
    "    no_repeat_ngram_size=4,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")\n",
    "\n",
    "# top k\n",
    "top_k_gen_conf = GenerationConfig(\n",
    "    do_sample=True,\n",
    "    top_k=100,\n",
    "    max_length=320,\n",
    "    max_new_tokens=256,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")\n",
    "\n",
    "# top p\n",
    "top_p_gen_conf = GenerationConfig(\n",
    "    do_sample=True,\n",
    "    top_k=0,\n",
    "    top_p=0.95,\n",
    "    max_length=320,\n",
    "    max_new_tokens=256,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "##提问:\n",
      "感冒了要怎么办？\n",
      "##回答:\n",
      "感冒是由病毒引起的，感冒一般由病毒引起，以下是一些常见感冒的方法：\n",
      "- 洗手，特别是在接触其他人或物品后。\n",
      "- 咳嗽或打喷嚏时用纸巾或手肘遮住口鼻。\n",
      "- 用手触摸口鼻，特别是喉咙和鼻子。\n",
      "- 如果咳嗽或打喷嚏，可以用纸巾或手绢来遮住口鼻，但要远离其他人。\n",
      "- 如果你感冒了，最好不要触摸自己的眼睛、鼻子和嘴巴。\n",
      "- 在感冒期间，最好保持充足的水分和休息，以缓解身体的疲劳。\n",
      "- 如果您已经感冒了，可以喝一些温水或盐水来补充体液。\n",
      "- 另外，如果感冒了，建议及时就医。\n"
     ]
    }
   ],
   "source": [
    "tokend = tokenizer.encode_plus(text=prompt)\n",
    "input_ids, attention_mask = torch.LongTensor([tokend.input_ids]).to(device), torch.LongTensor([tokend.attention_mask]).to(device)\n",
    "\n",
    "outputs = dpo_model.generate(\n",
    "    inputs=input_ids,\n",
    "    attention_mask=attention_mask,\n",
    "    generation_config=gen_conf,\n",
    ")\n",
    "\n",
    "outs = tokenizer.decode(outputs[0].cpu().numpy(), clean_up_tokenization_spaces=True, skip_special_tokens=True,)\n",
    "print(outs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 计算困惑度（perplexity）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load\n",
    "perplexity = load(\"perplexity\", module_type=\"metric\")\n",
    "dpo_model_save_dir = './model_save/dpo/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "429d58cc576a45a2a46f8185380b1b6b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = perplexity.compute(predictions=['生成一个关于食品的描述，包括其因素、口感和用途。', '在感冒期间，最好保持充足的水分和休息，以缓解身体的疲劳。'], add_start_token=False, model_id=dpo_model_save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [159.40676879882812, 97.79098510742188], 'mean_perplexity': 128.598876953125}\n"
     ]
    }
   ],
   "source": [
    "print(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
